# Copyright 2025 © BeeAI a Series of LF Projects, LLC
# SPDX-License-Identifier: Apache-2.0

import json
from typing import Annotated, Any, TypeAlias

from pydantic import BaseModel, Discriminator, Field, Tag, field_validator

from beeai_framework.backend.utils import parse_broken_json


class ChatToolFunctionDefinition(BaseModel):
    """The function part of a tool call: the function name and its arguments."""

    name: str
    arguments: dict[str, Any]

    # ``@field_validator`` must be the outermost decorator (pydantic documents it above
    # ``@classmethod``). With ``@classmethod`` on top, pydantic does not see the validator
    # and never runs it, so JSON-string arguments would be rejected.
    @field_validator("arguments", mode="before")
    @classmethod
    def parse_arguments(cls, value: Any) -> Any:
        """Decode ``arguments`` when it is supplied as a JSON-encoded string.

        Strings are parsed with ``parse_broken_json``; any other value is passed through
        unchanged so pydantic can validate it as ``dict[str, Any]``.

        Raises:
            ValueError: if the string cannot be decoded as JSON.
        """
        if not isinstance(value, str):
            return value
        try:
            return parse_broken_json(value)
        except json.JSONDecodeError as err:
            raise ValueError("Invalid JSON format for arguments") from err


class ChatToolCall(BaseModel):
    """A single tool call, as listed in ``AssistantMessage.tool_calls``."""

    # Identifier of the call; a ``ToolMessage`` answers it via ``tool_call_id``.
    id: str = Field(default=...)
    # Function name plus (possibly JSON-string-encoded) arguments.
    function: ChatToolFunctionDefinition = Field(default=...)
    # Kind of tool call; accepted as any string.
    type: str = Field(default=...)


class TextContentPart(BaseModel):
    """A plain-text content part of a message."""

    text: str = Field(default=..., description="The text content.")
    type: str = Field(default="text", description="The type of the content part.")


class Image(BaseModel):
    """The image referenced by an image content part."""

    url: str = Field(..., description="Either a URL of the image or the base64 encoded image data.")
    # Optional: clients may omit ``detail``, in which case it is ``None``. A bare ``Field(...)``
    # here would make the key mandatory even though the type allows ``None``.
    detail: str | None = Field(None, description="Specifies the detail level of the image.")


class ImageContentPart(BaseModel):
    """An image content part wrapping an ``Image`` reference."""

    image_url: Image
    # NOTE(review): the sibling parts default ``type`` to their Chat Completions names ("text",
    # "input_audio", "file", "refusal"); for an ``image_url`` object that name would presumably be
    # "image_url" rather than "input_image" — confirm what consumers expect before changing it.
    type: str = Field(
        default="input_image",
        description="The type of the content part.",
    )


class Audio(BaseModel):
    """Inline audio payload for an audio content part."""

    data: str = Field(default=..., description="Base64 encoded audio data.")
    format: str = Field(default=..., description="The format of the encoded audio data.")


class AudioContentPart(BaseModel):
    """An audio content part wrapping an ``Audio`` payload."""

    input_audio: Audio
    type: str = Field(default="input_audio", description="The type of the content part. Always 'input_audio'.")


class File(BaseModel):
    """A file attached to a message, given inline (data + name) or by the ID of an uploaded file.

    Every field is optional.
    """

    file_data: str | None = Field(
        default=None,
        description="The base64 encoded file data, used when passing the file to the model as a string.",
    )
    file_id: str | None = Field(default=None, description="The ID of an uploaded file to use as input.")
    filename: str | None = Field(
        default=None,
        description="The name of the file, used when passing the file to the model as a string.",
    )


class FileContentPart(BaseModel):
    """A file content part wrapping a ``File``."""

    file: File
    type: str = Field(default="file", description="The type of the content part. Always 'file'.")


class RefusalContentPart(BaseModel):
    """A content part carrying a refusal message produced by the model."""

    refusal: str = Field(default=..., description="The refusal message generated by the model.")
    type: str = Field(default="refusal", description="The type of the content part.")


class DeveloperMessage(BaseModel):
    """A message authored under the 'developer' role; content is text only."""

    role: str = Field(
        default="developer",
        description="The role of the messages author, in this case 'developer'.",
    )
    name: str | None = Field(
        default=None,
        description=(
            "An optional name for the participant. "
            "Provides the model information to differentiate between participants of the same role."
        ),
    )
    content: str | list[TextContentPart] = Field(default=..., description="The contents of the developer message.")


class SystemMessage(BaseModel):
    """A message authored under the 'system' role; content is text only."""

    role: str = Field(
        default="system",
        description="The role of the messages author, in this case 'system'.",
    )
    name: str | None = Field(
        default=None,
        description=(
            "An optional name for the participant. "
            "Provides the model information to differentiate between participants of the same role."
        ),
    )
    content: str | list[TextContentPart] = Field(default=..., description="The contents of the system message.")


class UserMessage(BaseModel):
    """A message authored under the 'user' role.

    Content is either a plain string or a list of text, image, audio, and file parts.
    """

    role: str = Field(
        default="user",
        description="The role of the messages author, in this case 'user'.",
    )
    name: str | None = Field(
        default=None,
        description=(
            "An optional name for the participant. "
            "Provides the model information to differentiate between participants of the same role."
        ),
    )
    content: str | list[TextContentPart | ImageContentPart | AudioContentPart | FileContentPart] = Field(
        default=...,
        description="The contents of the user message.",
    )


class AssistantMessage(BaseModel):
    """A message authored under the 'assistant' role.

    Every field except ``role`` is optional: the message may carry text/refusal content, a
    top-level refusal string, and/or the tool calls the model generated.
    """

    role: str = Field(
        "assistant",
        description="The role of the messages author, in this case 'assistant'.",
    )
    name: str | None = Field(
        None,
        description="An optional name for the participant. "
        "Provides the model information to differentiate between participants of the same role.",
    )
    # Description previously said "user message" (copy-paste from UserMessage).
    content: str | list[TextContentPart | RefusalContentPart] | None = Field(
        None,
        description="The contents of the assistant message.",
    )
    refusal: str | None = Field(
        None,
        description="The refusal message by the assistant.",
    )
    tool_calls: list[ChatToolCall] | None = Field(
        None, description="The tool calls generated by the model, such as function calls."
    )


class ToolMessage(BaseModel):
    """A message under the 'tool' role that responds to a specific tool call."""

    role: str = Field("tool", description="The role of the messages author, in this case 'tool'.")
    # ``content`` is required and not nullable, so the description must not claim it can be null.
    content: str | list[TextContentPart] = Field(
        ...,
        description="The contents of the tool message.",
    )
    tool_call_id: str = Field(
        ...,
        description="Tool call that this message is responding to.",
    )


ContentPart: TypeAlias = TextContentPart | AudioContentPart | FileContentPart | RefusalContentPart | ImageContentPart

# Roles that have a dedicated message model. Any other (or missing) role is routed to
# ``UserMessage``, which is where the plain union sent such messages before.
_CHAT_MESSAGE_ROLES = frozenset({"user", "developer", "system", "tool", "assistant"})


def _chat_message_role(value: Any) -> str:
    """Return the union tag that selects the ``ChatMessage`` model for ``value``.

    Handles both raw input (a ``dict``, during validation) and already-built message
    models (during serialization). Unknown or missing roles map to ``"user"``.
    """
    role = value.get("role") if isinstance(value, dict) else getattr(value, "role", None)
    if isinstance(role, str) and role in _CHAT_MESSAGE_ROLES:
        return role
    return "user"


# Every message model types ``role`` as plain ``str`` with a default, so a bare
# ``UserMessage | ... | AssistantMessage`` union is ambiguous. Pydantic's smart-mode union resolves
# ties to the leftmost member, so e.g. ``{"role": "system", "content": "..."}`` became a
# ``UserMessage``. Dispatching on the ``role`` value picks the model that matches the role.
ChatMessage: TypeAlias = Annotated[
    Annotated[UserMessage, Tag("user")]
    | Annotated[DeveloperMessage, Tag("developer")]
    | Annotated[SystemMessage, Tag("system")]
    | Annotated[ToolMessage, Tag("tool")]
    | Annotated[AssistantMessage, Tag("assistant")],
    Discriminator(_chat_message_role),
]


class ChatCompletionRequestBody(BaseModel):
    """Request body of a chat completion call."""

    # ``model`` has no default, so requests must supply it; the description therefore must not
    # promise a fallback model that this schema does not provide.
    model: str = Field(description="ID of the model to use.")
    messages: list[ChatMessage] = Field(..., description="List of messages in the conversation")
    stream: bool | None = Field(False, description="Whether to stream responses as server-sent events")


class ChatMessageResponse(BaseModel):
    """The message carried by a completion choice."""

    # The regex pattern admits only the 'user' and 'assistant' roles.
    role: str = Field(default=..., description="The role of the message sender", pattern="^(user|assistant)$")
    content: str = Field(default=..., description="The content of the message")


class ChatCompletionChoice(BaseModel):
    """One entry of ``ChatCompletionResponse.choices``."""

    index: int = Field(default=..., description="The index of the choice")
    message: ChatMessageResponse = Field(default=..., description="The message")
    finish_reason: str | None = Field(default=None, description="The reason the message generation finished")


class ChatCompletionUsage(BaseModel):
    """Token counts reported for a completion; all three counts are required."""

    prompt_tokens: int = Field(default=..., description="Number of prompt tokens")
    completion_tokens: int = Field(default=..., description="Number of generated tokens")
    total_tokens: int = Field(default=..., description="Number of total tokens")


class ChatCompletionResponse(BaseModel):
    """Top-level chat completion response object; ``usage`` is optional."""

    id: str = Field(default=..., description="Unique identifier for the completion")
    object: str = Field(
        default="chat.completion",
        description="The type of object returned, should be 'chat.completion'",
    )
    created: int = Field(default=..., description="Timestamp of when the completion was created")
    model: str = Field(default=..., description="The model used for generating the completion")
    choices: list[ChatCompletionChoice] = Field(default=..., description="List of completion choices")
    usage: ChatCompletionUsage | None = Field(default=None, description="The usage of the completion")
